{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WrcIOXsUQh8U"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "tXAbWHtqs1Y2"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HTgMAvQq-PU_"
      },
      "source": [
        "# Extension types\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/extension_type\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/guide/extension_type.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/guide/extension_type.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/extension_type.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jHcw9MtgBo7e"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0MsE_F0WBpmc"
      },
      "outputs": [],
      "source": [
        "!pip install -q tf_nightly\n",
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "from typing import Tuple, List, Mapping, Union, Optional\n",
        "import tempfile"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1BAk3bji_0wl"
      },
      "source": [
        "## Extension types\n",
        "\n",
        "User-defined types can make projects more readable, modular, maintainable. However, most TensorFlow APIs have very limited support for user-defined Python types. This includes both high-level APIs (such as [Keras](https://www.tensorflow.org/guide/keras/overview), [tf.function](https://www.tensorflow.org/guide/function), [`tf.SavedModel`](https://www.tensorflow.org/guide/saved_model)) and lower-level APIs (such as `tf.while_loop` and `tf.concat`). TensorFlow **extension types** can be used to create user-defined object-oriented types that work seamlessly with TensorFlow's APIs. To create an extension type, simply define a Python class with `tf.experimental.ExtensionType` as its base, and use [type annotations](https://www.python.org/dev/peps/pep-0484/) to specify the type for each field."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7o5KY7L5_nxy"
      },
      "outputs": [],
      "source": [
        "class TensorGraph(tf.experimental.ExtensionType):\n",
        "  \"\"\"A collection of labeled nodes connected by weighted edges.\"\"\"\n",
        "  edge_weights: tf.Tensor               # shape=[num_nodes, num_nodes]\n",
        "  node_labels: Mapping[str, tf.Tensor]  # shape=[num_nodes]; dtype=any\n",
        "\n",
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor       # shape=values.shape; false for missing/invalid values.\n",
        "\n",
        "class CSRSparseMatrix(tf.experimental.ExtensionType):\n",
        "  \"\"\"Compressed sparse row matrix (https://en.wikipedia.org/wiki/Sparse_matrix).\"\"\"\n",
        "  values: tf.Tensor     # shape=[num_nonzero]; dtype=any\n",
        "  col_index: tf.Tensor  # shape=[num_nonzero]; dtype=int64\n",
        "  row_index: tf.Tensor  # shape=[num_rows+1]; dtype=int64"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FiaNXPa7pNK-"
      },
      "source": [
        "The `tf.experimental.ExtensionType` base class works similarly to [`typing.NamedTuple`](https://docs.python.org/3/library/typing.html#typing.NamedTuple) and [`@dataclasses.dataclass`](https://docs.python.org/3/library/dataclasses.html#dataclasses.dataclass) from the standard Python library. In particular, it automatically adds a constructor and special methods (such as `__repr__` and `__eq__`) based on the field type annotations."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JsE7X6_uMyLo"
      },
      "source": [
        "Typically, extension types tend to fall into one of two categories:\n",
        "\n",
        "* ***Data structures***, which group together a collection of related values, and can provide useful operations based on those values. Data structures may be fairly general (such as the `TensorGraph` example above); or they may be highly customized to a specific model.\n",
        "\n",
        "* ***Tensor-like types***, which specialize or extend the concept of \"Tensor.\" Types in this category have a `rank`, a `shape`, and usually a `dtype`; and it makes sense to use them with Tensor operations (such as `tf.stack`, `tf.add`, or `tf.matmul`). `MaskedTensor` and `CSRSparseMatrix` are examples of tensor-like types."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uxngcajlMqIY"
      },
      "source": [
        "## Supported APIs\n",
        "\n",
        "Extension types are supported by the following TensorFlow APIs:\n",
        "\n",
        "* **Keras**: Extension types can be used as inputs and outputs for Keras `Models` and `Layers`.\n",
        "* **`tf.data.Dataset`**: Extension types can be included in `Datasets`, and returned by dataset `Iterators`.\n",
        "* **TensorFlow Hub**: Extension types can be used as inputs and outputs for `tf.hub` modules.\n",
        "* **SavedModel**: Extension types can be used as inputs and outputs for `SavedModel` functions.\n",
        "* **`tf.function`**: Extension types can be used as arguments and return values for functions wrapped with the `@tf.function` decorator.\n",
        "* **While loops**: Extension types can be used as loop variables in `tf.while_loop`, and can be used as arguments and return values for the while-loop's body.\n",
        "* **Conditionals**: Extension types can be conditionally selected using `tf.cond` and `tf.case`.\n",
        "* **`tf.py_function`**: Extension types can be used as arguments and return values for the `func` argument to `tf.py_function`.\n",
        "* **Tensor ops**: Extension types can be extended to support most TensorFlow ops that accept Tensor inputs (such as `tf.matmul`, `tf.gather`, and `tf.reduce_sum`). Go to the \"*Dispatch*\" section below for more information.\n",
        "* **Distribution strategy**: Extension types can be used as per-replica values.\n",
        "\n",
        "For more details, see the section on \"TensorFlow APIs that support ExtensionTypes\" below.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VIpZwuPVpwOX"
      },
      "source": [
        "## Requirements\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nNk_TQeJGVwV"
      },
      "source": [
        "### Field types\n",
        "\n",
        "All fields—instance variables—must be declared, and a type annotation must be provided for each field. The following type annotations are supported:\n",
        "\n",
        "Type | Example\n",
        "---- | -------\n",
        "Python integers | `i: int`\n",
        "Python floats | `f: float`\n",
        "Python strings | `s: str`\n",
        "Python booleans | `b: bool`\n",
        "Python `None` | `n: None`\n",
        "[Tensor shapes](https://www.tensorflow.org/api_docs/python/tf/TensorShape) | `shape: tf.TensorShape`\n",
        "[Tensor `dtype`s](https://www.tensorflow.org/api_docs/python/tf/dtypes/DType) | `dtype: tf.DType`\n",
        "[Tensors](https://www.tensorflow.org/api_docs/python/tf/Tensor) | `t: tf.Tensor`\n",
        "[Extension types](https://www.tensorflow.org/api_docs/python/tf/experimental/ExtensionType) | `mt: MyMaskedTensor`\n",
        "[Ragged tensors](https://www.tensorflow.org/api_docs/python/tf/RaggedTensor) | `rt: tf.RaggedTensor`\n",
        "[Sparse tensors](https://www.tensorflow.org/api_docs/python/tf/sparse/SparseTensor) | `st: tf.SparseTensor`\n",
        "[Indexed slices](https://www.tensorflow.org/api_docs/python/tf/IndexedSlices) | `s: tf.IndexedSlices`\n",
        "[Optional tensors](https://www.tensorflow.org/api_docs/python/tf/experimental/Optional) | `o: tf.experimental.Optional`\n",
        "[Type unions](https://docs.python.org/3/library/typing.html#typing.Union) | `int_or_float: typing.Union[int, float]`\n",
        "[Tuples](https://docs.python.org/3/library/typing.html#typing.Tuple) | `params: typing.Tuple[int, float, tf.Tensor, int]`\n",
        "[Var-length tuples](https://docs.python.org/3/library/typing.html#typing.Tuple) | `lengths: typing.Tuple[int, ...]`\n",
        "[Mappings](https://docs.python.org/3/library/typing.html#typing.Mapping) | `tags: typing.Mapping[str, tf.Tensor]`\n",
        "[Optional values](https://docs.python.org/3/library/typing.html#typing.Optional) | `weight: typing.Optional[tf.Tensor]`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iFetYyZsIvf6"
      },
      "source": [
        "### Mutability\n",
        "\n",
        "Extension types are required to be immutable. This ensures that they can be properly tracked by TensorFlow's graph-tracing mechanisms.\n",
        "If you find yourself wanting to mutate an extension type value, consider instead defining methods that transform values. For example, rather than defining a `set_mask` method to mutate a `MaskedTensor`, you could define a `replace_mask` method that returns a new `MaskedTensor`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DThZLYH2IwFh"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  def replace_mask(self, new_mask):\n",
        "      self.values.shape.assert_is_compatible_with(new_mask.shape)\n",
        "      return MaskedTensor(self.values, new_mask)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "x3JyivI_qAtt"
      },
      "source": [
        "## Functionality added by `ExtensionType`\n",
        "\n",
        "The `ExtensionType` base class provides the following functionality:\n",
        "\n",
        "* A constructor (`__init__`).\n",
        "* A printable representation method (`__repr__`).\n",
        "* Equality and inequality operators (`__eq__`).\n",
        "* A validation method (`__validate__`).\n",
        "* Enforced immutability.\n",
        "* A nested `TypeSpec`.\n",
        "* Tensor API dispatch support.\n",
        "\n",
        "Go to the \"Customizing `ExtensionType`s\" section below for more information on customizing this functionality."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pfSYs6P26gKq"
      },
      "source": [
        "### Constructor\n",
        "The constructor added by `ExtensionType` takes each field as a named argument (in the order they were listed in the class definition). This constructor will type-check each parameter, and convert them where necessary. In particular, `Tensor` fields are converted using `tf.convert_to_tensor`; `Tuple` fields are converted to `tuple`s; and `Mapping` fields are converted to immutable dicts."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DiXwyZ5M5KFW"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "# Constructor takes one parameter for each field.\n",
        "mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],\n",
        "                  mask=[[True, True, False], [True, False, True]])\n",
        "\n",
        "# Fields are type-checked and converted to the declared types.\n",
        "# For example, `mt.values` is converted to a Tensor.\n",
        "print(mt.values)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ezNDe1cYF0Qb"
      },
      "source": [
        "The constructor raises an `TypeError` if a field value can not be converted to its declared type:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6HnrMaabF5VS"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  MaskedTensor([1, 2, 3], None)\n",
        "except TypeError as e:\n",
        "  print(f\"Got expected TypeError: {e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FwQUI3X02s20"
      },
      "source": [
        "The default value for a field can be specified by setting its value at the class level:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GbzDT9fz20JA"
      },
      "outputs": [],
      "source": [
        "class Pencil(tf.experimental.ExtensionType):\n",
        "  color: str = \"black\"\n",
        "  has_erasor: bool = True\n",
        "  length: tf.Tensor = 1.0\n",
        "\n",
        "Pencil()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "nOW7lS9P4Foc"
      },
      "outputs": [],
      "source": [
        "Pencil(length=0.5, color=\"blue\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "S5Eivtg07Aau"
      },
      "source": [
        "### Printable representation\n",
        "\n",
        "`ExtensionType` adds a default printable representation method (`__repr__`) that includes the class name and the value for each field:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5SyiKTe55krG"
      },
      "outputs": [],
      "source": [
        "print(MaskedTensor(values=[1, 2, 3], mask=[True, True, False]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "q4l_gnQh6nXR"
      },
      "source": [
        "### Equality operators\n",
        "\n",
        "`ExtensionType` adds default equality operators (`__eq__` and `__ne__`) that consider two values equal if they have the same type and all their fields are equal. Tensor fields are considered equal if they have the same shape and are elementwise equal for all elements."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bHdLg13V52Xm"
      },
      "outputs": [],
      "source": [
        "a = MaskedTensor([1, 2], [True, False])\n",
        "b = MaskedTensor([[3, 4], [5, 6]], [[False, True], [True, True]])\n",
        "print(f\"a == a: {a==a}\")\n",
        "print(f\"a == b: {a==b}\")\n",
        "print(f\"a == a.values: {a==a.values}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "O3HqsO3jZlQq"
      },
      "source": [
        "**Note:** if any field contains a `Tensor`, then `__eq__` may return a scalar boolean `Tensor` (rather than a Python boolean value)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hCpBfkKqCuip"
      },
      "source": [
        "### Validation method\n",
        "\n",
        "`ExtensionType` adds a `__validate__` method, which can be overridden to perform validation checks on fields. It is run after the constructor is called, and after fields have been type-checked and converted to their declared types, so it can assume that all fields have their declared types.\n",
        "\n",
        "The following example updates `MaskedTensor` to validate the `shape`s and `dtype`s of its fields:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dgZOJRINDn00"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "  def __validate__(self):\n",
        "    self.values.shape.assert_is_compatible_with(self.mask.shape)\n",
        "    assert self.mask.dtype.is_bool, 'mask.dtype must be bool'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ajSgkGUUn9WL"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  MaskedTensor([1, 2, 3], [0, 1, 0])  # Wrong `dtype` for mask.\n",
        "except AssertionError as e:\n",
        "  print(f\"Got expected AssertionError: {e}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Fhb96luJn9K7"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  MaskedTensor([1, 2, 3], [True, False])  # shapes don't match.\n",
        "except ValueError as e:\n",
        "  print(f\"Got expected ValueError: {e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pjIPAF1OCAdO"
      },
      "source": [
        "### Enforced immutability\n",
        "\n",
        "`ExtensionType` overrides the `__setattr__` and `__delattr__` methods to prevent mutation, ensuring that extension type values are immutable."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NgmJ1C7ilN5C"
      },
      "outputs": [],
      "source": [
        "mt = MaskedTensor([1, 2, 3], [True, False, True])"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cMYmJr3RoFKp"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  mt.mask = [True, True, True]\n",
        "except AttributeError as e:\n",
        "  print(f\"Got expected AttributeError: {e}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZWwA-zWdzqlU"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  mt.mask[0] = False\n",
        "except TypeError as e:\n",
        "  print(f\"Got expected TypeError: {e}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PN_txJVKoFoF"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  del mt.mask\n",
        "except AttributeError as e:\n",
        "  print(f\"Got expected AttributeError: {e}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FBVFtCYn69Ou"
      },
      "source": [
        "### Nested TypeSpec\n",
        "\n",
        "Each `ExtensionType` class has a corresponding `TypeSpec` class, which is created automatically and stored as `<extension_type_name>.Spec`.\n",
        "\n",
        "This class captures all the information from a value *except* for the values of any nested tensors. In particular, the `TypeSpec` for a value is created by replacing any nested Tensor, ExtensionType, or CompositeTensor with its `TypeSpec`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GRjANkGYKGnV"
      },
      "outputs": [],
      "source": [
        "class Player(tf.experimental.ExtensionType):\n",
        "  name: tf.Tensor\n",
        "  attributes: Mapping[str, tf.Tensor]\n",
        "\n",
        "anne = Player(\"Anne\", {\"height\": 8.3, \"speed\": 28.1})\n",
        "anne_spec = tf.type_spec_from_value(anne)\n",
        "print(anne_spec.name)  # Records `dtype` and `shape`, but not the string value.\n",
        "print(anne_spec.attributes)  # Records keys and TensorSpecs for values."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I2fkgckxO564"
      },
      "source": [
        "`TypeSpec` values can be constructed explicitly, or they can be built from an `ExtensionType` value using `tf.type_spec_from_value`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1ehAa7d9OGai"
      },
      "outputs": [],
      "source": [
        "spec1 = Player.Spec(name=tf.TensorSpec([], tf.float32), attributes={})\n",
        "spec2 = tf.type_spec_from_value(anne)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "owcFG3cAMCwA"
      },
      "source": [
        "`TypeSpec`s are used by TensorFlow to divide values into a **static component** and a **dynamic component**:\n",
        "\n",
        "* The **static component** (which is fixed at graph-construction time) is encoded with a `tf.TypeSpec`.\n",
        "* The **dynamic component** (which can vary each time the graph is run) is encoded as a list of `tf.Tensor`s.\n",
        "\n",
        "For example, `tf.function` retraces its wrapped function whenever an argument has a previously unseen `TypeSpec`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pg-m5YLRM1Nd"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def anonymize_player(player):\n",
        "  print(\"<<TRACING>>\")\n",
        "  return Player(\"<anonymous>\", player.attributes)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0CCGm7cpeIq-"
      },
      "outputs": [],
      "source": [
        "# Function gets traced (first time the function has been called):\n",
        "anonymize_player(Player(\"Anne\", {\"height\": 8.3, \"speed\": 28.1}))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WB7bt7s83mFE"
      },
      "outputs": [],
      "source": [
        "# Function does NOT get traced (same TypeSpec: just tensor values changed)\n",
        "anonymize_player(Player(\"Bart\", {\"height\": 8.1, \"speed\": 25.3}))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dNm7vLpR3nMH"
      },
      "outputs": [],
      "source": [
        "# Function gets traced (new TypeSpec: keys for attributes changed):\n",
        "anonymize_player(Player(\"Chuck\", {\"height\": 11.0, \"jump\": 5.3}))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U5rN1HPq25xC"
      },
      "source": [
        "For more information, see the [tf.function Guide](https://www.tensorflow.org/guide/function#rules_of_tracing)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gX613uRk0qLz"
      },
      "source": [
        "## Customizing `ExtensionType`s\n",
        "\n",
        "In addition to simply declaring fields and their types, extension types may:\n",
        "\n",
        "* Override the default printable representation (`__repr__`).\n",
        "* Define methods.\n",
        "* Define `classmethod`s and `staticmethod`s.\n",
        "* Define properties.\n",
        "* Override the default constructor (`__init__`).\n",
        "* Override the default equality operator (`__eq__`).\n",
        "* Define operators (such as `__add__` and `__lt__`).\n",
        "* Declare default values for fields.\n",
        "* Define subclasses.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MK-ePVDj-ROE"
      },
      "source": [
        "### Overriding the default printable representation\n",
        "\n",
        "You can override this default string conversion operator for extension types. The following example updates the `MaskedTensor` class to generate a more readable string representation when values are printed in Eager mode."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gdPhjYEr8IGO"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  \"\"\"A tensor paired with a boolean mask, indicating which values are valid.\"\"\"\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor       # shape=values.shape; false for invalid values.\n",
        "\n",
        "  def __repr__(self):\n",
        "    return masked_tensor_str(self.values, self.mask)\n",
        "\n",
        "def masked_tensor_str(values, mask):\n",
        "  if isinstance(values, tf.Tensor):\n",
        "    if hasattr(values, 'numpy') and hasattr(mask, 'numpy'):\n",
        "      return f'<MaskedTensor {masked_tensor_str(values.numpy(), mask.numpy())}>'\n",
        "    else:\n",
        "      return f'MaskedTensor(values={values}, mask={mask})'\n",
        "  if len(values.shape) == 1:\n",
        "    items = [repr(v) if m else '_' for (v, m) in zip(values, mask)]\n",
        "  else:\n",
        "    items = [masked_tensor_str(v, m) for (v, m) in zip(values, mask)]\n",
        "  return '[%s]' % ', '.join(items)\n",
        "\n",
        "mt = MaskedTensor(values=[[1, 2, 3], [4, 5, 6]],\n",
        "                  mask=[[True, True, False], [True, False, True]])\n",
        "print(mt)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_MLQU2_v8VjG"
      },
      "source": [
        "### Defining methods\n",
        "\n",
        "Extension types may define methods, just like any normal Python class. For example, the `MaskedTensor` type could define a `with_default` method that returns a copy of `self` with masked values replaced by a given `default` value. Methods may optionally be annotated with the `@tf.function` decorator."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7RR-tqee8ZdP"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  def with_default(self, default):\n",
        "    return tf.where(self.mask, self.values, default)\n",
        "\n",
        "MaskedTensor([1, 2, 3], [True, False, True]).with_default(0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Qwd_gGKp9RP0"
      },
      "source": [
        "### Defining `classmethod`s and `staticmethod`s\n",
        "\n",
        "Extension types may define methods using the `@classmethod` and `@staticmethod` decorators. For example, the `MaskedTensor` type could define a factory method that masks any element with a given value:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BacCEJYU9sBR"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  def __repr__(self):\n",
        "    return masked_tensor_str(self.values, self.mask)\n",
        "\n",
        "  @staticmethod\n",
        "  def from_tensor_and_value_to_mask(values, value_to_mask):\n",
        "    return MaskedTensor(values, values != value_to_mask)\n",
        "\n",
        "x = tf.constant([[1, 0, 2], [3, 0, 0]])\n",
        "MaskedTensor.from_tensor_and_value_to_mask(x, 0)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xIPf9PZX9AwL"
      },
      "source": [
        "### Defining properties\n",
        "Extension types may define properties using the `@property` decorator, just like any normal Python class. For example, the `MaskedTensor` type could define a `dtype` property that's a shorthand for the `dtype` of the values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "16E68wZ-9KXp"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  @property\n",
        "  def dtype(self):\n",
        "    return self.values.dtype\n",
        "\n",
        "MaskedTensor([1, 2, 3], [True, False, True]).dtype"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Mm5gxoG57nf3"
      },
      "source": [
        "### Overriding the default constructor\n",
        "\n",
        "You can override the default constructor for extension types. Custom constructors must set a value for every declared field; and after the custom constructor returns, all fields will be type-checked, and values will be converted as described above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-8K3KeB08G1S"
      },
      "outputs": [],
      "source": [
        "class Toy(tf.experimental.ExtensionType):\n",
        "  name: str\n",
        "  price: tf.Tensor\n",
        "  def __init__(self, name, price, discount=0):\n",
        "    self.name = name\n",
        "    self.price = price * (1 - discount)\n",
        "\n",
        "print(Toy(\"ball\", 5.0, discount=0.2))  # On sale -- 20% off!"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qyQxMlwLFQt7"
      },
      "source": [
        "Alternatively, you might consider leaving the default constructor as-is, but adding one or more factory methods. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jiApK4hzFY89"
      },
      "outputs": [],
      "source": [
        "class Toy(tf.experimental.ExtensionType):\n",
        "  name: str\n",
        "  price: tf.Tensor\n",
        "\n",
        "  @staticmethod\n",
        "  def new_toy_with_discount(name, price, discount):\n",
        "    return Toy(name, price * (1 - discount))\n",
        "\n",
        "print(Toy.new_toy_with_discount(\"ball\", 5.0, discount=0.2))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "pdVcRBhG-Uee"
      },
      "source": [
        "### Overriding the default equality operator (`__eq__`)\n",
        "\n",
        "You can override the default `__eq__` operator for extension types. The following example updates `MaskedTensor` to ignore masked elements when comparing for equality."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "dA7DyjfB-Yz0"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  def __repr__(self):\n",
        "    return masked_tensor_str(self.values, self.mask)\n",
        "\n",
        "  def __eq__(self, other):\n",
        "    result = tf.math.equal(self.values, other.values)\n",
        "    result = result | ~(self.mask & other.mask)\n",
        "    return tf.reduce_all(result)\n",
        "\n",
        "x = MaskedTensor([1, 2, 3, 4], [True, True, False, True])\n",
        "y = MaskedTensor([5, 2, 0, 4], [False, True, False, True])\n",
        "print(x == y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "n1mZ1Lkyi14B"
      },
      "source": [
        "**Note:** You generally don't need to override `__ne__`, since its default implementation simply calls `__eq__` and negates the result."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A_Jib1SQD1-z"
      },
      "source": [
        "### Using forward references\n",
        "\n",
        "If the type for a field has not been defined yet, you may use a string containing the name of the type instead. In the following example, the string `\"Node\"` is used to annotate the `children` field because the `Node` type hasn't been (fully) defined yet.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Z029QKED0Ao"
      },
      "outputs": [],
      "source": [
        "class Node(tf.experimental.ExtensionType):\n",
        "  value: tf.Tensor\n",
        "  children: Tuple[\"Node\", ...] = ()\n",
        "\n",
        "Node(3, [Node(5), Node(2)])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "boaNg1zHgoVn"
      },
      "source": [
        "### Defining subclasses\n",
        "\n",
        "Extension types may be subclassed using the standard Python syntax. Extension type subclasses may add new fields, methods, and properties; and may override the constructor, the printable representation, and the equality operator. The following example defines a basic `TensorGraph` class that uses three `Tensor` fields to encode a set of edges between nodes. It then defines a subclass that adds a `Tensor` field to record a \"feature value\" for each node. The subclass also defines a method to propagate the feature values along the edges."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "58r6qRiK-uZh"
      },
      "outputs": [],
      "source": [
        "class TensorGraph(tf.experimental.ExtensionType):\n",
        "  num_nodes: tf.Tensor\n",
        "  edge_src: tf.Tensor   # edge_src[e] = index of src node for edge e.\n",
        "  edge_dst: tf.Tensor   # edge_dst[e] = index of dst node for edge e.\n",
        "\n",
        "class TensorGraphWithNodeFeature(TensorGraph):\n",
        "  node_features: tf.Tensor  # node_features[n] = feature value for node n.\n",
        "\n",
        "  def propagate_features(self, weight=1.0) -> 'TensorGraphWithNodeFeature':\n",
        "    updates = tf.gather(self.node_features, self.edge_src) * weight\n",
        "    new_node_features = tf.tensor_scatter_nd_add(\n",
        "        self.node_features, tf.expand_dims(self.edge_dst, 1), updates)\n",
        "    return TensorGraphWithNodeFeature(\n",
        "        self.num_nodes, self.edge_src, self.edge_dst, new_node_features)\n",
        "\n",
        "g = TensorGraphWithNodeFeature(  # Edges: 0->1, 4->3, 2->2, 2->1\n",
        "    num_nodes=5, edge_src=[0, 4, 2, 2], edge_dst=[1, 3, 2, 1],\n",
        "    node_features=[10.0, 0.0, 2.0, 5.0, -1.0])  # One feature value per node.\n",
        "\n",
        "print(\"Original features:\", g.node_features)\n",
        "print(\"After propagating:\", g.propagate_features().node_features)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U_oElT5HzqSG"
      },
      "source": [
        "### Defining private fields\n",
        "\n",
        "An extension type's fields may be marked private by prefixing them with an underscore (following standard Python conventions). This does not change how TensorFlow treats the fields; it simply signals to users of the extension type that those fields are private.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oMdH7ORqh8Pl"
      },
      "source": [
        "### Customizing the `ExtensionType`'s `TypeSpec`\n",
        "\n",
        "Each `ExtensionType` class has a corresponding `TypeSpec` class, which is created automatically and stored as `<extension_type_name>.Spec`. For more information, see the section \"Nested TypeSpec\" above.\n",
        "\n",
        "To customize the `TypeSpec`, simply define your own nested class named `Spec`, and `ExtensionType` will use that as the basis for the automatically constructed `TypeSpec`. You can customize the `Spec` class by:\n",
        "\n",
        "* Overriding the default printable representation.\n",
        "* Overriding the default constructor.\n",
        "* Defining methods, `classmethod`s, `staticmethod`s, and properties.\n",
        "\n",
        "The following example customizes the `MaskedTensor.Spec` class to make it easier to use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gm4RaqbkLlNG"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.ExtensionType):\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  shape = property(lambda self: self.values.shape)\n",
        "  dtype = property(lambda self: self.values.dtype)\n",
        "\n",
        "  def __repr__(self):\n",
        "    return masked_tensor_str(self.values, self.mask)\n",
        "\n",
        "  def with_values(self, new_values):\n",
        "    return MaskedTensor(new_values, self.mask)\n",
        "\n",
        "  class Spec:\n",
        "    def __init__(self, shape, dtype=tf.float32):\n",
        "      self.values = tf.TensorSpec(shape, dtype)\n",
        "      self.mask = tf.TensorSpec(shape, tf.bool)\n",
        "\n",
        "    def __repr__(self):\n",
        "      return f\"MaskedTensor.Spec(shape={self.shape}, dtype={self.dtype})\"\n",
        "\n",
        "    shape = property(lambda self: self.values.shape)\n",
        "    dtype = property(lambda self: self.values.dtype)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "s3zzUXPSNF72"
      },
      "source": [
        "**Note**: The custom `Spec` class cannot use any instance variables that were not declared in the original `ExtensionType`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rip4GCuYPL7o"
      },
      "source": [
        "## Tensor API dispatch\n",
        "\n",
        "Extension types can be \"tensor-like\", in the sense that they specialize or extend the interface defined by the `tf.Tensor` type. Examples of tensor-like extension types include `RaggedTensor`, `SparseTensor`, and `MaskedTensor`. ***Dispatch decorators*** can be used to override the default behavior of TensorFlow operations when applied to tensor-like extension types. TensorFlow currently defines three dispatch decorators:\n",
        "\n",
        "* `@tf.experimental.dispatch_for_api(tf_api)`\n",
        "* `@tf.experimental.dispatch_for_unary_elementwise_apis(x_type)`\n",
        "* `@tf.experimental.dispatch_for_binary_elementwise_apis(x_type, y_type)`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5BTQHcY4gHwZ"
      },
      "source": [
        "### Dispatch for a single API\n",
        "\n",
        "The `tf.experimental.dispatch_for_api` decorator overrides the default behavior of a specified TensorFlow operation when it is called with the specified signature. For example, you can use this decorator to specify how `tf.stack` should process `MaskedTensor` values:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "B4QgO_fUW2o2"
      },
      "outputs": [],
      "source": [
        "@tf.experimental.dispatch_for_api(tf.stack)\n",
        "def masked_stack(values: List[MaskedTensor], axis = 0):\n",
        "  return MaskedTensor(tf.stack([v.values for v in values], axis),\n",
        "                      tf.stack([v.mask for v in values], axis))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FxKcKWNUaLvm"
      },
      "source": [
        "This overrides the default implementation for `tf.stack` whenever it is called with a list of `MaskedTensor` values (since the `values` argument is annotated with `typing.List[MaskedTensor]`):"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RqpFjaAvaA19"
      },
      "outputs": [],
      "source": [
        "x = MaskedTensor([1, 2, 3], [True, True, False])\n",
        "y = MaskedTensor([4, 5, 6], [False, True, True])\n",
        "tf.stack([x, y])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "loGi8taCa265"
      },
      "source": [
        "To allow `tf.stack` to handle lists of mixed `MaskedTensor` and `Tensor` values, you can refine the type annotation for the `values` parameter and update the body of the function appropriately:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_xySkm0ganAI"
      },
      "outputs": [],
      "source": [
        "tf.experimental.unregister_dispatch_for(masked_stack)\n",
        "\n",
        "def convert_to_masked_tensor(x):\n",
        "  if isinstance(x, MaskedTensor):\n",
        "    return x\n",
        "  else:\n",
        "    return MaskedTensor(x, tf.ones_like(x, tf.bool))\n",
        "\n",
        "@tf.experimental.dispatch_for_api(tf.stack)\n",
        "def masked_stack_v2(values: List[Union[MaskedTensor, tf.Tensor]], axis = 0):\n",
        "  values = [convert_to_masked_tensor(v) for v in values]\n",
        "  return MaskedTensor(tf.stack([v.values for v in values], axis),\n",
        "                      tf.stack([v.mask for v in values], axis))\n",
        "x = MaskedTensor([1, 2, 3], [True, True, False])\n",
        "y = tf.constant([4, 5, 6])\n",
        "tf.stack([x, y, x])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ITioFCyjQm8V"
      },
      "source": [
        "For a list of APIs that can be overridden, see the API documentation for `tf.experimental.dispatch_for_api`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f91SaHSqc-jO"
      },
      "source": [
        "### Dispatch for all unary elementwise APIs\n",
        "\n",
        "The `tf.experimental.dispatch_for_unary_elementwise_apis` decorator overrides the default behavior of ***all*** unary elementwise ops (such as `tf.math.cos`) whenever the value for the first argument (typically named `x`) matches the type annotation `x_type`. The decorated function should take two arguments:\n",
        "\n",
        "* `api_func`: A function that takes a single parameter and performs the elementwise operation (for example, `tf.abs`).\n",
        "* `x`: The first argument to the elementwise operation.\n",
        "\n",
        "The following example updates all unary elementwise operations to handle the `MaskedTensor` type:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "cv5fV4xxZI9q"
      },
      "outputs": [],
      "source": [
        "@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)\n",
        "def masked_tensor_unary_elementwise_api_handler(api_func, x):\n",
        "  return MaskedTensor(api_func(x.values), x.mask)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qiK4n6vaeFwo"
      },
      "source": [
        "This function will now be used whenever a unary elementwise operation is called on a `MaskedTensor`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SkH0xi5gd_41"
      },
      "outputs": [],
      "source": [
        "x = MaskedTensor([1, -2, -3], [True, False, True])\n",
        "print(tf.abs(x))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2Ej5fxLBfaXW"
      },
      "outputs": [],
      "source": [
        "print(tf.ones_like(x, dtype=tf.float32))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Z9OgLyfEejqc"
      },
      "source": [
        "### Dispatch for all binary elementwise APIs\n",
        "\n",
        "Similarly, `tf.experimental.dispatch_for_binary_elementwise_apis` can be used to update all binary elementwise operations to handle the `MaskedTensor` type:\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z8Du-GPofpCW"
      },
      "outputs": [],
      "source": [
        "@tf.experimental.dispatch_for_binary_elementwise_apis(MaskedTensor, MaskedTensor)\n",
        "def masked_tensor_binary_elementwise_api_handler(api_func, x, y):\n",
        "  return MaskedTensor(api_func(x.values, y.values), x.mask & y.mask)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "gghVHDfSfyi2"
      },
      "outputs": [],
      "source": [
        "x = MaskedTensor([1, -2, -3], [True, False, True])\n",
        "y = MaskedTensor([[4], [5]], [[True], [False]])\n",
        "tf.math.add(x, y)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "txTGg9pzG0Ux"
      },
      "source": [
        "For a list of the elementwise APIs that are overridden, go to the API documentation for `tf.experimental.dispatch_for_unary_elementwise_apis` and `tf.experimental.dispatch_for_binary_elementwise_apis`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UseRtohYKiE5"
      },
      "source": [
        "## Batchable `ExtensionType`s\n",
        "\n",
        "An `ExtensionType` is *batchable* if a single instance can be used to represent a batch of values. Typically, this is accomplished by adding batch dimensions to all nested `Tensor`s. The following TensorFlow APIs require that any extension type inputs be batchable:\n",
        "\n",
        "* `tf.data.Dataset` (`batch`, `unbatch`, `from_tensor_slices`)\n",
        "* `tf.keras` (`fit`, `evaluate`, `predict`)\n",
        "* `tf.map_fn`"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hWPauKGj_yRz"
      },
      "source": [
        "By default, `BatchableExtensionType` creates batched values by batching any nested `Tensor`s, `CompositeTensor`s, and `ExtensionType`s. If this is not appropriate for your class, then you will need to use `tf.experimental.ExtensionTypeBatchEncoder` to override this default behavior. For example, it would not be appropriate to create a batch of `tf.SparseTensor` values by simply stacking individual sparse tensors' `values`, `indices`, and `dense_shape` fields -- in most cases, you can't stack these tensors, since they have incompatible shapes; and even if you could, the result would not be a valid `SparseTensor`.\n",
        "\n",
        "\n",
        "**Note**: `BatchableExtensionType`s do *not* automatically define dispatchers for `tf.stack`, `tf.concat`, `tf.slice`, etc. If your class needs to be supported by these APIs, then use the dispatch decorators described above."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xkOJ8ke8GH7s"
      },
      "source": [
        "### `BatchableExtensionType` example: `Network`\n",
        "As an example, consider a simple `Network` class used for load balancing, which tracks how much work is left to do at each node, and how much bandwidth is available to move work between nodes:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tOeEXwCcfrPd"
      },
      "outputs": [],
      "source": [
        "class Network(tf.experimental.ExtensionType):  # This version is not batchable.\n",
        "  work: tf.Tensor       # work[n] = work left to do at node n\n",
        "  bandwidth: tf.Tensor  # bandwidth[n1, n2] = bandwidth from n1->n2\n",
        "\n",
        "net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])\n",
        "net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PaOzUev6g3wT"
      },
      "source": [
        "To make this type batchable, change the base type to `BatchableExtensionType`, and adjust the shape of each field to include optional batch dimensions. The following example also adds a `shape` field to keep track of the batch shape. This `shape` field is not required by `tf.data.Dataset` or `tf.map_fn`, but it *is* required by `tf.keras`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T03WWBSMg2XC"
      },
      "outputs": [],
      "source": [
        "class Network(tf.experimental.BatchableExtensionType):\n",
        "  shape: tf.TensorShape  # batch shape. A single network has shape=[].\n",
        "  work: tf.Tensor        # work[*shape, n] = work left to do at node n\n",
        "  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2\n",
        "\n",
        "  def __init__(self, work, bandwidth):\n",
        "    self.work = tf.convert_to_tensor(work)\n",
        "    self.bandwidth = tf.convert_to_tensor(bandwidth)\n",
        "    work_batch_shape = self.work.shape[:-1]\n",
        "    bandwidth_batch_shape = self.bandwidth.shape[:-2]\n",
        "    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)\n",
        "\n",
        "  def __repr__(self):\n",
        "    return network_repr(self)\n",
        "\n",
        "def network_repr(network):\n",
        "  work = network.work\n",
        "  bandwidth = network.bandwidth\n",
        "  if hasattr(work, 'numpy'):\n",
        "    work = ' '.join(str(work.numpy()).split())\n",
        "  if hasattr(bandwidth, 'numpy'):\n",
        "    bandwidth = ' '.join(str(bandwidth.numpy()).split())\n",
        "  return (f\"<Network shape={network.shape} work={work} bandwidth={bandwidth}>\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NUUJe9HuIPel"
      },
      "outputs": [],
      "source": [
        "net1 = Network([5., 3, 8], [[0., 2, 0], [2, 0, 3], [0, 3, 0]])\n",
        "net2 = Network([3., 4, 2], [[0., 2, 2], [2, 0, 2], [2, 2, 0]])\n",
        "batch_of_networks = Network(\n",
        "    work=tf.stack([net1.work, net2.work]),\n",
        "    bandwidth=tf.stack([net1.bandwidth, net2.bandwidth]))\n",
        "print(f\"net1={net1}\")\n",
        "print(f\"net2={net2}\")\n",
        "print(f\"batch={batch_of_networks}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "r0qWur5JGc3d"
      },
      "source": [
        "You can then use `tf.data.Dataset` to iterate through a batch of networks:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BN_kixAUFZtv"
      },
      "outputs": [],
      "source": [
        "dataset = tf.data.Dataset.from_tensor_slices(batch_of_networks)\n",
        "for i, network in enumerate(dataset):\n",
        "  print(f\"Batch element {i}: {network}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aXENhTzIIjbM"
      },
      "source": [
        "You can also use `tf.map_fn` to apply a function to each batch element:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "j1XEsSWj9a3D"
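      },
      "outputs": [],
      "source": [
        "def balance_work_greedy(network):\n",
        "  # delta[..., i, j] > 0 means node i has more work than node j.\n",
        "  delta = (tf.expand_dims(network.work, -1) - tf.expand_dims(network.work, -2))\n",
        "  delta /= 4\n",
        "  # Limit each transfer to the bandwidth available between the two nodes.\n",
        "  delta = tf.maximum(tf.minimum(delta, network.bandwidth), -network.bandwidth)\n",
        "  # Node i sends delta[..., i, j] to node j, so subtract its outgoing total.\n",
        "  new_work = network.work - tf.reduce_sum(delta, -1)\n",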
        "  return Network(new_work, network.bandwidth)\n",
        "\n",
        "tf.map_fn(balance_work_greedy, batch_of_networks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f_HLsTT02Xul"
      },
      "source": [
        "## TensorFlow APIs that support `ExtensionType`s"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NNiQad2U2alT"
      },
      "source": [
        "### @tf.function\n",
        "\n",
        "[`tf.function`](https://www.tensorflow.org/guide/function) is a decorator that precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Extension type values can be used transparently with `@tf.function`-decorated functions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "jQ_rAvrA6qEb"
      },
      "outputs": [],
      "source": [
        "class Pastry(tf.experimental.ExtensionType):\n",
        "  sweetness: tf.Tensor  # 2d embedding that encodes sweetness\n",
        "  chewiness: tf.Tensor  # 2d embedding that encodes chewiness\n",
        "\n",
        "@tf.function\n",
        "def combine_pastry_features(x: Pastry):\n",
        "  return (x.sweetness + x.chewiness) / 2\n",
        "\n",
        "cookie = Pastry(sweetness=[1.2, 0.4], chewiness=[0.8, 0.2])\n",
        "combine_pastry_features(cookie)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "u1P-0Udg71Vx"
      },
      "source": [
        "If you wish to explicitly specify the `input_signature` for `tf.function`, then you can do so using the extension type's `TypeSpec`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "0df90E4x78d7"
      },
      "outputs": [],
      "source": [
        "pastry_spec = Pastry.Spec(tf.TensorSpec([2]), tf.TensorSpec([2]))\n",
        "\n",
        "@tf.function(input_signature=[pastry_spec])\n",
        "def increase_sweetness(x: Pastry, delta=1.0):\n",
        "  return Pastry(x.sweetness + delta, x.chewiness)\n",
        "\n",
        "increase_sweetness(cookie)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CdTfc5nD9JpD"
      },
      "source": [
        "#### Concrete functions\n",
        "Concrete functions encapsulate individual traced graphs that are built by `tf.function`. Extension types can be used transparently with concrete functions.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FyHBBQWk9xz2"
      },
      "outputs": [],
      "source": [
        "cf = combine_pastry_features.get_concrete_function(pastry_spec)\n",
        "cf(cookie)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "LYas8gtG5IMA"
      },
      "source": [
        "### Control flow operations\n",
        "\n",
        "Extension types are supported by TensorFlow's control-flow operations:\n",
        "\n",
        "* `tf.cond`\n",
        "* `tf.case`\n",
        "* `tf.while_loop`\n",
        "* `tf.identity`\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6G2XE9ZtJu8z"
      },
      "outputs": [],
      "source": [
        "# Example: using tf.cond to select between two MaskedTensors. Note that the\n",
        "# two MaskedTensors don't need to have the same shape.\n",
        "a = MaskedTensor([1., 2, 3], [True, False, True])\n",
        "b = MaskedTensor([22., 33, 108, 55], [True, True, True, False])\n",
        "condition = tf.constant(True)\n",
        "print(tf.cond(condition, lambda: a, lambda: b))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2NwLOw1kKSek"
      },
      "outputs": [],
      "source": [
        "# Example: using tf.while_loop with MaskedTensor.\n",
        "cond = lambda i, _: i < 10\n",
        "def body(i, mt):\n",
        "  return i + 1, mt.with_values(mt.values + 3 / 7)\n",
        "print(tf.while_loop(cond, body, [0, b])[1])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zkN7IuWVMRzn"
      },
      "source": [
        "### Autograph control flow\n",
        "\n",
        "Extension types are also supported by control flow statements in `tf.function` (using autograph). In the following example, the `if` and `for` statements are automatically converted to `tf.cond` and `tf.while_loop` operations, which support extension types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4RFySEl8gZ8w"
      },
      "outputs": [],
      "source": [
        "@tf.function\n",
        "def fn(x, b):\n",
        "  if b:\n",
        "    x = MaskedTensor(x, tf.less(x, 0))\n",
        "  else:\n",
        "    x = MaskedTensor(x, tf.greater(x, 0))\n",
        "  for i in tf.range(5 if b else 7):\n",
        "    x = x.with_values(x.values + 1 / 2)\n",
        "  return x\n",
        "\n",
        "print(fn(tf.constant([1., -2, 3]), tf.constant(True)))\n",
        "print(fn(tf.constant([1., -2, 3]), tf.constant(False)))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-FjZt2ohfja4"
      },
      "source": [
        "### Keras\n",
        "\n",
        "[tf.keras](https://www.tensorflow.org/guide/keras) is TensorFlow's high-level API for building and training deep learning models. Extension types may be passed as inputs to a Keras model, passed between Keras layers, and returned by Keras models. Keras currently puts two requirements on extension types:\n",
        "\n",
        "* They must be batchable (go to \"Batchable `ExtensionType`s\" above).\n",
        "* They must have a field or property named `shape`. `shape[0]` is assumed to be the batch dimension.\n",
        "\n",
        "The following two subsections give examples showing how extension types can be used with Keras.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QH1TXQYiGv8u"
      },
      "source": [
        "#### Keras example: `Network`\n",
        "\n",
        "For the first example, consider the `Network` class defined in the \"Batchable `ExtensionType`s\" section above, which can be used for load balancing work between nodes. Its definition is repeated here:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zHj1RIS2PK50"
      },
      "outputs": [],
      "source": [
        "class Network(tf.experimental.BatchableExtensionType):\n",
        "  shape: tf.TensorShape  # batch shape. A single network has shape=[].\n",
        "  work: tf.Tensor        # work[*shape, n] = work left to do at node n\n",
        "  bandwidth: tf.Tensor   # bandwidth[*shape, n1, n2] = bandwidth from n1->n2\n",
        "\n",
        "  def __init__(self, work, bandwidth):\n",
        "    self.work = tf.convert_to_tensor(work)\n",
        "    self.bandwidth = tf.convert_to_tensor(bandwidth)\n",
        "    work_batch_shape = self.work.shape[:-1]\n",
        "    bandwidth_batch_shape = self.bandwidth.shape[:-2]\n",
        "    self.shape = work_batch_shape.merge_with(bandwidth_batch_shape)\n",
        "\n",
        "  def __repr__(self):\n",
        "    return network_repr(self)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "w9LPTEVJD0FD"
      },
      "outputs": [],
      "source": [
        "single_network = Network(  # A single network with 4 nodes.\n",
        "    work=[8.0, 5, 12, 2],\n",
        "    bandwidth=[[0.0, 1, 2, 2], [1, 0, 0, 2], [2, 0, 0, 1], [2, 2, 1, 0]])\n",
        "\n",
        "batch_of_networks = Network(  # Batch of 2 networks, each w/ 2 nodes.\n",
        "    work=[[8.0, 5], [3, 2]],\n",
        "    bandwidth=[[[0.0, 1], [1, 0]], [[0, 2], [2, 0]]])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IUfWi3SDD0dj"
      },
      "source": [
        "You can define a new Keras layer that processes `Network`s."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2WSYt58r4SF1"
      },
      "outputs": [],
      "source": [
        "class BalanceNetworkLayer(tf.keras.layers.Layer):\n",
        "  \"\"\"Layer that balances work between nodes in a network.\n",
        "\n",
        "  Shifts work from more busy nodes to less busy nodes, constrained by bandwidth.\n",
        "  \"\"\"\n",
        "  def call(self, inputs):\n",
        "    # This function is defined above in the \"Batchable `ExtensionType`s\" section.\n",
        "    return balance_work_greedy(inputs)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VWwFJNb1E03q"
      },
      "source": [
        "You can then use these layers to create a simple model. To feed an `ExtensionType` into a model, you can use a `tf.keras.layers.Input` layer with `type_spec` set to the extension type's `TypeSpec`. If the Keras model will be used to process batches, then the `type_spec` must include the batch dimension."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "plTyqISRExA4"
      },
      "outputs": [],
      "source": [
        "input_spec = Network.Spec(shape=None,\n",
        "                          work=tf.TensorSpec(None, tf.float32),\n",
        "                          bandwidth=tf.TensorSpec(None, tf.float32))\n",
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Input(type_spec=input_spec),\n",
        "    BalanceNetworkLayer(),\n",
        "    ])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hyeAbt1WFIiO"
      },
      "source": [
        "Finally, you can apply the model to a single network and to a batch of networks."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "hH1EtA5lFHdN"
      },
      "outputs": [],
      "source": [
        "model(single_network)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "V7eM67M7FYYM"
      },
      "outputs": [],
      "source": [
        "model(batch_of_networks)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "tOxtt9Z1HDCv"
      },
      "source": [
        "#### Keras example: `MaskedTensor`\n",
        "\n",
        "In this example, `MaskedTensor` is extended to support Keras. `shape` is defined as a property that is calculated from the `values` field. Keras requires that you add this property to both the extension type and its `TypeSpec`. `MaskedTensor` also defines a `__name__` variable, which will be required for `SavedModel` serialization (below)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1JBZ_t48Ht7e"
      },
      "outputs": [],
      "source": [
        "class MaskedTensor(tf.experimental.BatchableExtensionType):\n",
        "  # __name__ is required for serialization in SavedModel; see below for details.\n",
        "  __name__ = 'extension_type_colab.MaskedTensor'\n",
        "\n",
        "  values: tf.Tensor\n",
        "  mask: tf.Tensor\n",
        "\n",
        "  shape = property(lambda self: self.values.shape)\n",
        "  dtype = property(lambda self: self.values.dtype)\n",
        "\n",
        "  def with_default(self, default):\n",
        "    return tf.where(self.mask, self.values, default)\n",
        "\n",
        "  def __repr__(self):\n",
        "    return masked_tensor_str(self.values, self.mask)\n",
        "\n",
        "  class Spec:\n",
        "    def __init__(self, shape, dtype=tf.float32):\n",
        "      self.values = tf.TensorSpec(shape, dtype)\n",
        "      self.mask = tf.TensorSpec(shape, tf.bool)\n",
        "\n",
        "    shape = property(lambda self: self.values.shape)\n",
        "    dtype = property(lambda self: self.values.dtype)\n",
        "\n",
        "    def with_shape(self):\n",
        "      return MaskedTensor.Spec(tf.TensorSpec(shape, self.values.dtype),\n",
        "                               tf.TensorSpec(shape, self.mask.dtype))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oer8BVc8H7_V"
      },
      "source": [
        "Next, the dispatch decorators are used to override the default behavior of several TensorFlow APIs. Since these APIs are used by standard Keras layers (such as the `Dense` layer), overriding these will allow us to use those layers with `MaskedTensor`. For the purposes of this example, `matmul` for masked tensors is defined to treat the masked values as zeros (that is, to not include them in the product)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xy0dhQ_b-ca_"
      },
      "outputs": [],
      "source": [
        "@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)\n",
        "def unary_elementwise_op_handler(op, x):\n",
        "  return MaskedTensor(op(x.values), x.mask)\n",
        "\n",
        "@tf.experimental.dispatch_for_binary_elementwise_apis(\n",
        "    Union[MaskedTensor, tf.Tensor],\n",
        "    Union[MaskedTensor, tf.Tensor])\n",
        "def binary_elementwise_op_handler(op, x, y):\n",
        "  x = convert_to_masked_tensor(x)\n",
        "  y = convert_to_masked_tensor(y)\n",
        "  return MaskedTensor(op(x.values, y.values), x.mask & y.mask)\n",
        "\n",
        "@tf.experimental.dispatch_for_api(tf.matmul)\n",
        "def masked_matmul(a: MaskedTensor, b,\n",
        "                  transpose_a=False, transpose_b=False,\n",
        "                  adjoint_a=False, adjoint_b=False,\n",
        "                  a_is_sparse=False, b_is_sparse=False,\n",
        "                  output_type=None,\n",
        "                  grad_a=False, grad_b=False,\n",
        "                  name=None,\n",
        "                  ):\n",
        "  if isinstance(a, MaskedTensor):\n",
        "    a = a.with_default(0)\n",
        "  if isinstance(b, MaskedTensor):\n",
        "    b = b.with_default(0)\n",
        "  return tf.matmul(a, b, transpose_a, transpose_b, adjoint_a,\n",
        "                   adjoint_b, a_is_sparse, b_is_sparse,\n",
        "                   output_type)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "osJ_L-fKJusI"
      },
      "source": [
        "You can then construct a Keras model that accepts `MaskedTensor` inputs, using standard Keras layers:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "IS6JCVbk1rd0"
      },
      "outputs": [],
      "source": [
        "input_spec = MaskedTensor.Spec([None, 2], tf.float32)\n",
        "\n",
        "masked_tensor_model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Input(type_spec=input_spec),\n",
        "    tf.keras.layers.Dense(16, activation=\"relu\"),\n",
        "    tf.keras.layers.Dense(1)])\n",
        "masked_tensor_model.compile(loss='binary_crossentropy', optimizer='rmsprop')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "SB1WUSzn1RPj"
      },
      "outputs": [],
      "source": [
        "a = MaskedTensor([[1., 2], [3, 4], [5, 6]],\n",
        "                  [[True, False], [False, True], [True, True]])\n",
        "masked_tensor_model.fit(a, tf.constant([[1], [0], [1]]), epochs=3)\n",
        "print(masked_tensor_model(a))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "msmd9XcL2bqb"
      },
      "source": [
        "### SavedModel\n",
        "\n",
        "A [SavedModel](https://www.tensorflow.org/guide/saved_model) is a serialized TensorFlow program, including both weights and computation. It can be built from a Keras model or from a custom model. In either case, extension types can be used transparently with the functions and methods defined by a SavedModel.\n",
        "\n",
        "SavedModel can save models, layers, and functions that process extension types, as long as the extension types have a `__name__` field. This name is used to register the extension type, so it can be located when the model is loaded."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "PEtbFrz6-Vku"
      },
      "source": [
        "#### Example: saving a Keras model\n",
        "\n",
        "Keras models that use extension types may be saved using `SavedModel`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ecxQMnybSzV6"
      },
      "outputs": [],
      "source": [
        "masked_tensor_model_path = tempfile.mkdtemp()\n",
        "tf.saved_model.save(masked_tensor_model, masked_tensor_model_path)\n",
        "imported_model = tf.saved_model.load(masked_tensor_model_path)\n",
        "imported_model(a)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Ne2nu3r6-XMr"
      },
      "source": [
        "#### Example: saving a custom model\n",
        "\n",
        "SavedModel can also be used to save custom `tf.Module` subclasses with functions that process extension types."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2V6hV3yOT2vz"
      },
      "outputs": [],
      "source": [
        "class CustomModule(tf.Module):\n",
        "  def __init__(self, variable_value):\n",
        "    super().__init__()\n",
        "    self.v = tf.Variable(variable_value)\n",
        "\n",
        "  @tf.function\n",
        "  def grow(self, x: MaskedTensor):\n",
        "    \"\"\"Increase values in `x` by multiplying them by `self.v`.\"\"\"\n",
        "    return MaskedTensor(x.values * self.v, x.mask)\n",
        "\n",
        "module = CustomModule(100.0)\n",
        "\n",
        "module.grow.get_concrete_function(MaskedTensor.Spec(shape=None,\n",
        "                                                    dtype=tf.float32))\n",
        "custom_module_path = tempfile.mkdtemp()\n",
        "tf.saved_model.save(module, custom_module_path)\n",
        "imported_model = tf.saved_model.load(custom_module_path)\n",
        "imported_model.grow(MaskedTensor([1., 2, 3], [False, True, False]))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "o6beljh576ee"
      },
      "source": [
        "#### Loading a SavedModel when the `ExtensionType` is unavailable\n",
        "\n",
        "If you load a `SavedModel` that uses an `ExtensionType`, but that `ExtensionType` is not available (that is, it has not been imported), then you will get a warning and TensorFlow will fall back to using an \"anonymous extension type\" object. This object will have the same fields as the original type, but will lack any further customization you have added for the type, such as custom methods or properties."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ec9PcUkJ9bFK"
      },
      "source": [
        "#### Using `ExtensionType`s with TensorFlow Serving\n",
        "\n",
        "Currently, [TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) (and other consumers of the SavedModel \"signatures\" dictionary) requires that all inputs and outputs be raw tensors. If you wish to use TensorFlow Serving with a model that uses extension types, then you can add wrapper methods that compose or decompose extension type values from tensors. For example:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4VnzAwVo9tTc"
      },
      "outputs": [],
      "source": [
        "class CustomModuleWrapper(tf.Module):\n",
        "  def __init__(self, variable_value):\n",
        "    super().__init__()\n",
        "    self.v = tf.Variable(variable_value)\n",
        "\n",
        "  @tf.function\n",
        "  def var_weighted_mean(self, x: MaskedTensor):\n",
        "    \"\"\"Mean value of unmasked values in x, weighted by self.v.\"\"\"\n",
        "    x = MaskedTensor(x.values * self.v, x.mask)\n",
        "    return (tf.reduce_sum(x.with_default(0)) /\n",
        "            tf.reduce_sum(tf.cast(x.mask, x.dtype)))\n",
        "\n",
        "  @tf.function()\n",
        "  def var_weighted_mean_wrapper(self, x_values, x_mask):\n",
        "    \"\"\"Raw tensor wrapper for var_weighted_mean.\"\"\"\n",
        "    return self.var_weighted_mean(MaskedTensor(x_values, x_mask))\n",
        "\n",
        "module = CustomModuleWrapper([3., 2., 8., 5.])\n",
        "\n",
        "module.var_weighted_mean_wrapper.get_concrete_function(\n",
        "    tf.TensorSpec(None, tf.float32), tf.TensorSpec(None, tf.bool))\n",
        "custom_module_path = tempfile.mkdtemp()\n",
        "tf.saved_model.save(module, custom_module_path)\n",
        "imported_model = tf.saved_model.load(custom_module_path)\n",
        "x = MaskedTensor([1., 2., 3., 4.], [False, True, False, True])\n",
        "imported_model.var_weighted_mean_wrapper(x.values, x.mask)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4dwBadWQ5G9_"
      },
      "source": [
        "### `Dataset`s\n",
        "\n",
        "[`tf.data`](https://www.tensorflow.org/guide/data) is an API that enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is `tf.data.Dataset`, which represents a sequence of elements, in which each element consists of one or more components."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "GcIR19FuwRJV"
      },
      "source": [
        "#### Building `Dataset`s with extension types\n",
        "\n",
        "Datasets can be built from extension type values using `Dataset.from_tensors`, `Dataset.from_tensor_slices`, or `Dataset.from_generator`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Oe7fRCkzwdub"
      },
      "outputs": [],
      "source": [
        "ds = tf.data.Dataset.from_tensors(Pastry(5, 5))\n",
        "next(iter(ds))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fk9CD2fZx6yT"
      },
      "outputs": [],
      "source": [
        "mt = MaskedTensor(tf.reshape(range(20), [5, 4]), tf.ones([5, 4], tf.bool))\n",
        "ds = tf.data.Dataset.from_tensor_slices(mt)\n",
        "for value in ds:\n",
        "  print(value)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "DGw8y87awsOJ"
      },
      "outputs": [],
      "source": [
        "def value_gen():\n",
        "  for i in range(2, 7):\n",
        "    yield MaskedTensor(range(10), [j%i != 0 for j in range(10)])\n",
        "\n",
        "ds = tf.data.Dataset.from_generator(\n",
        "    value_gen, output_signature=MaskedTensor.Spec(shape=[10], dtype=tf.int32))\n",
        "for value in ds:\n",
        "  print(value)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "wfEm4NInyqtj"
      },
      "source": [
        "#### Batching and unbatching `Dataset`s with extension types\n",
        "\n",
        "Datasets with extension types can be batched and unbatched using `Dataset.batch` and `Dataset.unbatch`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "snoOUE1ay1rO"
      },
      "outputs": [],
      "source": [
        "batched_ds = ds.batch(2)\n",
        "for value in batched_ds:\n",
        "  print(value)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f8PTky6EzBVY"
      },
      "outputs": [],
      "source": [
        "unbatched_ds = batched_ds.unbatch()\n",
        "for value in unbatched_ds:\n",
        "  print(value)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "extension_type.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
